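Recap: yesterday we finished writing most of the code, and you can probably already feel how concise Lightning is.
Both training_step and validation_step use self.loss_fn, which is simply the loss function. If it is a simple one that you can call straight from nn, a single line is enough. For anything more complex, I usually create a separate loss_function.py and bring it in with an import.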
self.loss_fn = nn.CrossEntropyLoss()  # goes in __init__
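To give a rough idea of what a separate loss_function.py might contain, here is a minimal sketch of a custom loss written as an nn.Module. The focal loss, the FocalLoss class name, and the gamma parameter are illustrative assumptions and are not used elsewhere in this series. Any module that takes (preds, y) and returns a scalar would plug into self.loss_fn the same way.

```python
# loss_function.py (file name taken from the text above; contents are an illustration)
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Illustrative custom loss: cross entropy down-weighted for easy samples."""

    def __init__(self, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Per-sample cross entropy on raw logits (no softmax inside the model).
        ce = F.cross_entropy(logits, target, reduction="none")
        # Probability the model assigned to the true class.
        pt = torch.exp(-ce)
        # Easy samples (pt close to 1) contribute less; average over the batch.
        return ((1.0 - pt) ** self.gamma * ce).mean()
```

In the LightningModule you would then write `from loss_function import FocalLoss` and set `self.loss_fn = FocalLoss()` in __init__. With gamma set to 0 this sketch reduces to plain cross entropy, which is a handy sanity check.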
I will explain the settings below tomorrow. For now, just copy them so that we can get training running.
def main(task, max_epochs):
    model = example()

    if task == 'train':
        callbacks = []
        dirpath = "./model"
        # Keep the 5 checkpoints with the highest validation accuracy, plus the last one.
        checkpoint_acc = ModelCheckpoint(
            save_top_k = 5,
            monitor = "valid_acc_epoch",
            mode = "max",
            dirpath = dirpath,
            filename = "model_{epoch:02d}_{valid_acc_epoch:.2f}",
            save_last = True,
        )
        callbacks.append(checkpoint_acc)

        trainer = pl.Trainer(
            max_epochs = max_epochs,
            callbacks = callbacks,
            gradient_clip_val = 2.0,
            # devices = [0, 1]  # multi-GPU training
        )
        trainer.fit(model=model)
    # elif task == 'ft':


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description = "Training script")
    parser.add_argument('--task', type = str, default = 'train')
    parser.add_argument('--max_epochs', type = int, default = 50)
    args = parser.parse_args()
    main(args.task, args.max_epochs)
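To make the filename pattern a little more concrete before tomorrow's explanation: as far as I know, with its default settings ModelCheckpoint puts each metric name in front of its value, so the pattern above produces names along the lines of model_epoch=03_valid_acc_epoch=0.98.ckpt. The snippet below is only a plain-Python imitation of that expansion. expand_filename is a hypothetical helper, not a Lightning function, and the metric values are made up.

```python
import re


def expand_filename(pattern: str, metrics: dict) -> str:
    """Rough stand-in for how a checkpoint filename pattern is filled in."""
    # Turn "{name" into "name={name" so the metric name appears in the file name.
    labelled = re.sub(r"\{(\w+)", r"\1={\1", pattern)
    return labelled.format(**metrics) + ".ckpt"


# Made-up values, for illustration only.
name = expand_filename(
    "model_{epoch:02d}_{valid_acc_epoch:.2f}",
    {"epoch": 3, "valid_acc_epoch": 0.9812},
)
print(name)  # model_epoch=03_valid_acc_epoch=0.98.ckpt
```

Having the monitored metric in the file name makes it easy to pick the best checkpoint out of ./model at a glance.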
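Now you can run it. I made a couple of mistakes in the earlier posts: the Dataset's __len__, and an extra softmax in the model (nn.CrossEntropyLoss already applies log-softmax internally, so the model should output raw logits). If you have been typing the code along with me over the past two days, I am truly sorry; those posts have now been updated. Many bugs only show up when you actually run the code. Sometimes it is a slip of the fingers, sometimes a slip of the eyes, but quite often the logic itself is wrong or something was not considered. Those you can only debug patiently, and that is how you build up skill.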
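Here you can see that when you run it, it prints which devices are available and how many parameters the model has in total. When you are implementing a paper, checking that the parameter count matches is very important!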
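If you want to double-check the number Lightning prints, you can count the parameters yourself. A minimal sketch: the small nn.Sequential below is only a stand-in (the real MNISTClassifier comes from the earlier posts and is not reproduced here), and count_parameters is a hypothetical helper name.

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> int:
    """Total number of trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Stand-in model for illustration; not the MNISTClassifier from this series.
toy = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),  # 784 * 128 weights + 128 biases = 100,480
    nn.ReLU(),
    nn.Linear(128, 10),       # 128 * 10 weights + 10 biases = 1,290
)
print(count_parameters(toy))  # 101770
```

Working out the per-layer numbers by hand, as in the comments, is usually the fastest way to spot a layer whose shape differs from the paper.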
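Your directory now also contains a lightning_logs folder. This is the TensorBoard I keep mentioning. Go into that directory and run the command below, then open Chrome and enter the address in the form ip:port. The port is the first port in the mapping you used when you started the container. If you are running on your own machine, the ip is 127.0.0.1 and the port is 6066.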
tensorboard --logdir . --port 6066 --bind_all
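You can ignore the "not found" messages in the output. If the last lines show "TensorBoard <version> at" followed by an address, it has started successfully.
You can see the training curves and the accuracy. As for whether they match expectations, we will continue with that tomorrow~
Go through each step of the full code once more and check whether you really understand it.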
import argparse

import lightning as pl
import torch
import torch.nn as nn
import torchmetrics
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

from dataloader import CustomDataset
from model import MNISTClassifier


class example(pl.LightningModule):
    def __init__(
        self,
        batch_size = 128,
        train_txt = "/ws/code/Day8/train.txt",
        val_txt = "/ws/code/Day8/test.txt",
    ):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = CustomDataset(train_txt)
        self.val_dataset = CustomDataset(val_txt)
        self.model = MNISTClassifier()
        self.valid_acc = torchmetrics.classification.Accuracy(task = "multiclass", num_classes = 10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, batch):
        pass

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self.model(x)  # raw logits, no softmax
        loss = self.loss_fn(preds, y).mean()
        self.log("train/loss", loss.item(), prog_bar = True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self.model(x)
        loss = self.loss_fn(preds, y).mean()
        self.log("val/loss", loss.item(), prog_bar = True)
        # Accumulate statistics; accuracy is computed once per epoch.
        self.valid_acc.update(preds, y)

    def on_validation_epoch_end(self):
        self.log('valid_acc_epoch', self.valid_acc.compute())
        self.valid_acc.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = 1e-3)
        # lr_scheduler
        return optimizer  # [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size = self.batch_size,
            shuffle = True,
            drop_last = True,
            num_workers = 4,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size = self.batch_size,
            shuffle = False,
            drop_last = True,  # note: this skips the final partial validation batch
            num_workers = 4,
        )


def main(task, max_epochs):
    model = example()

    if task == 'train':
        callbacks = []
        dirpath = "./model"
        # Keep the 5 checkpoints with the highest validation accuracy, plus the last one.
        checkpoint_acc = ModelCheckpoint(
            save_top_k = 5,
            monitor = "valid_acc_epoch",
            mode = "max",
            dirpath = dirpath,
            filename = "model_{epoch:02d}_{valid_acc_epoch:.2f}",
            save_last = True,
        )
        callbacks.append(checkpoint_acc)

        trainer = pl.Trainer(
            max_epochs = max_epochs,
            callbacks = callbacks,
            gradient_clip_val = 2.0,
            # devices = [0, 1]  # multi-GPU training
        )
        trainer.fit(model=model)
    # elif task == 'ft':


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description = "Training script")
    parser.add_argument('--task', type = str, default = 'train')
    parser.add_argument('--max_epochs', type = int, default = 50)
    args = parser.parse_args()
    main(args.task, args.max_epochs)
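One detail in the code above worth spelling out is the update / compute / reset pattern used for valid_acc: counts are accumulated over all validation batches and turned into a single accuracy at the end of the epoch, instead of averaging per-batch accuracies. The class below is a plain-Python stand-in to show the idea. RunningAccuracy is hypothetical and is not how torchmetrics is implemented internally, and it takes predicted class ids rather than logits.

```python
class RunningAccuracy:
    """Accumulates correct/total counts across batches."""

    def __init__(self):
        self.reset()

    def update(self, pred_labels, true_labels):
        # Add this batch's counts to the running totals.
        self.correct += sum(p == t for p, t in zip(pred_labels, true_labels))
        self.total += len(true_labels)

    def compute(self):
        # Accuracy over everything seen since the last reset.
        return self.correct / self.total if self.total else 0.0

    def reset(self):
        # Called at the end of each epoch so that epochs do not mix.
        self.correct = 0
        self.total = 0


acc = RunningAccuracy()
acc.update([1, 2, 3, 4], [1, 2, 3, 0])  # batch of 4, 3 correct
acc.update([5, 5], [5, 5])              # batch of 2, 2 correct
print(acc.compute())  # 5 correct out of 6 samples
acc.reset()
```

With unequal batch sizes the two approaches differ: averaging the per-batch accuracies here would give (0.75 + 1.0) / 2 = 0.875, while the accumulated value is 5 / 6, which is the true accuracy over all samples.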
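That's it for today~ It took a lot of time to get even the simplest example working, but the overall structure is basically this. Once you have the basics down, you can go on to write Lightning code of your own for whatever direction you want to research.
Tomorrow I will clear up the parts that were not explained properly earlier.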